from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tool import common
import threading
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import optimizers
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.callbacks import ModelCheckpoint


def tic_train_type_thread():
    """Train the 17-class "tic" type classifier in a background thread.

    Builds a VGG16-like CNN for 224x224 single-channel images, trains it on
    the project's tic-type dataset, evaluates it, and saves the result to
    ``./model/tic_0324.h5``.

    Returns:
        threading.Thread: the already-started training thread, so callers may
        ``join()`` it. Previous callers that ignored the return value are
        unaffected.
    """

    def tf_train():
        """Build, train, evaluate, and persist the tic-type model."""
        # VGG16-style backbone; final dense layer has 17 softmax units,
        # one per tic class.
        model = keras.Sequential(
            [
                layers.Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu",
                              input_shape=[224, 224, 1]),
                layers.Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding="SAME"),
                layers.Conv2D(filters=128, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=128, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding="SAME"),
                layers.Conv2D(filters=256, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=256, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=256, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding="SAME"),
                layers.Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding="SAME"),
                layers.Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding="SAME"),
                layers.Flatten(),
                layers.Dense(512, activation="relu"),
                layers.Dropout(0.5),
                layers.Dense(256, activation="relu"),
                layers.Dropout(0.5),
                layers.Dense(17, activation='softmax', name='output')
            ])

        print(model.output.name)

        learning_rate = 1e-3 * 0.8
        epoch_num = 28

        def scheduler(epoch):
            # Piecewise-constant decay: full LR for the first 40% of epochs,
            # 10% of it until 70%, then 5% for the remainder.
            if epoch < epoch_num * 0.4:
                return learning_rate
            if epoch < epoch_num * 0.7:
                return learning_rate * 0.1
            return learning_rate * 0.05

        # `learning_rate=` is the supported keyword; the old `lr=` alias is
        # deprecated and removed in recent Keras releases.
        sgd = optimizers.SGD(learning_rate=learning_rate, momentum=0.9, nesterov=True)
        change_lr = LearningRateScheduler(scheduler)
        # Labels are integer class ids, hence sparse categorical crossentropy.
        model.compile(sgd, loss="sparse_categorical_crossentropy", metrics=["accuracy"])

        logdir = os.path.join("tflog")
        tensorboard = TensorBoard(log_dir=logdir, write_graph=True)

        # Stop early once training accuracy fails to improve for 2 epochs.
        early_stop = tf.keras.callbacks.EarlyStopping(monitor='accuracy', patience=2)

        model.fit(common.dataProcess_tic_type.get_dataset_train(), epochs=epoch_num,
                  callbacks=[change_lr, tensorboard, early_stop])

        model.evaluate(common.dataProcess_tic_type.get_dataset_test())

        # NOTE: the original code also called model.predict() here and
        # discarded the result; that was dead computation and is removed.

        # Ensure the target directory exists so save() does not fail on a
        # fresh checkout.
        os.makedirs('./model', exist_ok=True)
        model.save('./model/tic_0324.h5')

    th = threading.Thread(target=tf_train)
    th.start()
    return th



def axle_train_type_thread():
    """Train the 8-class "axle" type classifier in a background thread.

    Builds a VGG16-like CNN for 224x224 single-channel images, trains it on
    the project's axle-type dataset with per-epoch best-accuracy
    checkpointing, evaluates it, and saves the final model to
    ``./model/axle_0708.h5``.

    Returns:
        threading.Thread: the already-started training thread, so callers may
        ``join()`` it. Previous callers that ignored the return value are
        unaffected.
    """

    def tf_train():
        """Build, train, evaluate, and persist the axle-type model."""
        # VGG16-style backbone; final dense layer has 8 softmax units,
        # one per axle class.
        model = keras.Sequential(
            [
                layers.Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu",
                              input_shape=[224, 224, 1]),
                layers.Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding="SAME"),
                layers.Conv2D(filters=128, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=128, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding="SAME"),
                layers.Conv2D(filters=256, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=256, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=256, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding="SAME"),
                layers.Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding="SAME"),
                layers.Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1), padding="SAME", activation="relu"),
                layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding="SAME"),
                layers.Flatten(),
                layers.Dense(512, activation="relu"),
                layers.Dropout(0.5),
                layers.Dense(256, activation="relu"),
                layers.Dropout(0.5),
                layers.Dense(8, activation='softmax', name='output')
            ])

        print(model.output.name)

        learning_rate = 1e-3 * 0.8
        epoch_num = 26

        def scheduler(epoch):
            # Piecewise-constant decay: full LR for the first 40% of epochs,
            # 10% of it until 70%, then 2% for the remainder.
            if epoch < epoch_num * 0.4:
                return learning_rate
            if epoch < epoch_num * 0.7:
                return learning_rate * 0.1
            return learning_rate * 0.02

        # `learning_rate=` is the supported keyword; the old `lr=` alias is
        # deprecated and removed in recent Keras releases.
        sgd = optimizers.SGD(learning_rate=learning_rate, momentum=0.9, nesterov=True)
        change_lr = LearningRateScheduler(scheduler)
        # Labels are integer class ids, hence sparse categorical crossentropy.
        model.compile(sgd, loss="sparse_categorical_crossentropy", metrics=["accuracy"])

        logdir = os.path.join("tflog")
        tensorboard = TensorBoard(log_dir=logdir, write_graph=True, histogram_freq=1)

        # Ensure the checkpoint/save directory exists up front; both
        # ModelCheckpoint and model.save() fail on a missing directory.
        os.makedirs('./model', exist_ok=True)
        filepath = "./model/axle_0708_auto.h5"
        # Keep only the weights with the best training accuracy seen so far.
        checkpoint = ModelCheckpoint(filepath, monitor='accuracy', verbose=1, save_best_only=True, mode='max')

        model.fit(common.dataProcess_axle_type.get_dataset_train(), epochs=epoch_num,
                  callbacks=[change_lr, tensorboard, checkpoint])
        print('train over')

        model.evaluate(common.dataProcess_axle_type.get_dataset_test(), callbacks=[tensorboard])
        print('evaluate over')

        # NOTE(review): the predictions are discarded; this call appears to
        # exist only to drive the TensorBoard callback — confirm it is needed.
        model.predict(common.dataProcess_axle_type.get_dataset_test(), callbacks=[tensorboard])
        print('predict over')

        model.save('./model/axle_0708.h5')

        print('save over')

    th = threading.Thread(target=tf_train)
    th.start()
    return th